/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "tf_kernel_builder.h"

#include "util/constant.h"
#include "util/log.h"
#include "util/tf_util.h"
#include "error_code/error_code.h"
#include "ir2tf/ir2tf_parser_factory.h"
#include "runtime/kernel.h"
#include "common/util/error_manager/error_manager.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "runtime/rt_error_codes.h"

using domi::tensorflow::NodeDef;

namespace {
// Multiplier applied to the op's output count when sizing the mem-copy task
// data info in GenerateTask.
constexpr uint64_t kNum = 2;

// Task definition filled in by GetTaskInfoCallback and copied into the task
// list by GenerateTask.
domi::TaskDef g_task_def;

// Id of the op currently being processed; set by GenerateTask and written
// into the kernel_ex definitions as the op index.
int64_t g_op_index = 0;

// Serialized task payload assembled by BuildAndLaunchKernel (kernel run
// param, node def, optional function def) and consumed, then cleared, by
// GetTaskInfoCallback.
std::string g_task_info;

// The STR_FWK_OP_KERNEL instance created by BuildAndLaunchKernel; its bytes
// are used both as the launch arguments and as the task args in the callback.
std::shared_ptr<STR_FWK_OP_KERNEL> g_str_fwkop_kernel_ptr;
}  // namespace

namespace aicpu {
KernelBuilderPtr TfKernelBuilder::instance_ = nullptr;

/**
 * Returns the shared TfKernelBuilder instance, creating it on the first call.
 *
 * Creation is guarded by std::call_once, so concurrent first calls construct
 * the instance only once. The object is allocated with new (std::nothrow), so
 * the returned pointer is null if that allocation failed.
 *
 * The definition is deliberately not marked `inline`: an inline function must
 * be defined in every translation unit that uses it, and this definition only
 * exists in this source file.
 *
 * @return the shared builder, or nullptr if allocation failed.
 */
KernelBuilderPtr TfKernelBuilder::Instance() {
  static std::once_flag flag;
  // instance_ is a static member, so the lambda needs no captures.
  std::call_once(flag, []() {
    instance_.reset(new (std::nothrow) TfKernelBuilder);
  });
  return instance_;
}

/**
 * Initializes the builder by delegating to the base KernelBuilder.
 *
 * @return whatever KernelBuilder::Initialize() returns.
 */
ge::Status TfKernelBuilder::Initialize() {
  const ge::Status base_status = KernelBuilder::Initialize();
  return base_status;
}

/**
 * Calculates and records the running parameters of a TF kernel op.
 *
 * Memory blocks involved in a TF kernel task:
 *  1. STR_FWK_OP_KERNEL: the task's API struct;
 *  2. InputOutputBuf: a protobuf KernelRunParam describing the input and
 *     output buffers;
 *  3. NodeDefBuf: a protobuf NodeDef (TensorFlow definition) whose data comes
 *     from GE's graph via the IR transfer;
 *  4. FuncDef: a protobuf function definition for the fused graph.
 *
 * The workspace size is taken from the op description when its first entry is
 * already positive; otherwise it is computed by CalcWorkspaceSize and stored.
 * Afterwards the output sizes are set and the workspace reuse flag attribute
 * is set to {false}.
 *
 * @param node the GE node to process.
 * @return ge::SUCCESS on success, otherwise the error code of the failing step.
 */
ge::Status TfKernelBuilder::CalcOpRunningParam(const ge::Node &node) const {
  AICPUE_LOGI("TFKernel's op[%s] run CalcOpRunningParam", node.GetType().c_str());
  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL_ERRCODE(op_desc_ptr, INPUT_PARAM_NULL)

  int64_t workspace_size = 0;
  const std::vector<int64_t> recorded_bytes = op_desc_ptr->GetWorkspaceBytes();
  const bool already_recorded = (!recorded_bytes.empty()) && (recorded_bytes[0] > 0);
  if (already_recorded) {
    // Reuse the size that is already stored on the op description.
    workspace_size = recorded_bytes[0];
    AICPUE_LOGI("Op type[%s] Workspace size already exist, workspace_size is [%lld]",
                node.GetType().c_str(), workspace_size);
  } else {
    // Nothing usable stored yet: compute the size and store it.
    AICPU_CHECK_RES_WITH_LOG(CalcWorkspaceSize(node, workspace_size),
        "Call TfKernelBuilder::CalcWorkspaceSize function failed, op[%s].",
        node.GetName().c_str())
    op_desc_ptr->SetWorkspaceBytes({workspace_size});
  }

  AICPU_CHECK_RES_WITH_LOG(KernelBuilder::SetOutPutsSize(op_desc_ptr),
      "Call KernelBuilder::SetOutPutsSize function failed, op[%s].",
      node.GetName().c_str())

  // Mark the workspace memory as not reusable.
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::SetListBool(op_desc_ptr, kWorkspaceReuseFlag, {false}),
      AICPU_REPORT_CALL_ERROR(
          "Call ge::AttrUtils::SetListBool Failed to set attr[%s], op[%s].",
          kWorkspaceReuseFlag.c_str(), node.GetName().c_str());
      return ErrorCode::ADD_ATTR_FAILED)

  AICPUE_LOGI("Op type[%s] Calc the Op running param successfully, workspace_size is [%lld]",
              node.GetType().c_str(), workspace_size);
  return ge::SUCCESS;
}

/**
 * Calculates the workspace size needed by a node.
 *
 * The result is the serialized size of the node's KernelRunParam plus the
 * sizes of its node def and function def library, with both additions
 * checked for int64 overflow.
 *
 * @param node            the GE node to size.
 * @param workspace_size  receives the total size in bytes. It is first set to
 *                        the KernelRunParam size and only later increased, so
 *                        it may hold that partial value if a later step fails.
 * @return ge::SUCCESS on success, otherwise the error of the failing step.
 */
ge::Status TfKernelBuilder::CalcWorkspaceSize(const ge::Node &node,
                                              int64_t &workspace_size) const {
  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL(op_desc_ptr)
  // Step1 : Calc the InputOutputBuf's size (serialized KernelRunParam)
  FWKAdapter::KernelRunParam kernel_run_param;
  AICPU_CHECK_RES_WITH_LOG(BuildKernelRunParam(*op_desc_ptr, kernel_run_param),
      "Call TfKernelBuilder::BuildKernelRunParam function failed, op[%s].",
      node.GetName().c_str())

  int64_t kernel_run_param_size = static_cast<int64_t>(kernel_run_param.ByteSizeLong());
  AICPUE_LOGI("The kernel_run_param_size size of op type[%s] is [%lld]",
              node.GetType().c_str(), kernel_run_param_size);
  workspace_size = kernel_run_param_size;

  // Step2 : Get the tf's node_def and func_def's definition and size.
  // Only the sizes are used here; the byte buffers themselves are discarded.
  ge::GeAttrValue::BYTES node_def_bytes;
  ge::GeAttrValue::BYTES func_def_lib_bytes;
  int64_t node_def_size = 0;
  int64_t func_def_lib_size = 0;
  AICPU_CHECK_RES_WITH_LOG(ParseNodeDefAndFuncDef(node, node_def_bytes, func_def_lib_bytes, node_def_size, func_def_lib_size),
      "Call TfKernelBuilder::ParseNodeDefAndFuncDef function failed, op[%s].",
      node.GetName().c_str())
  // check overflow of node def size + function def lib size
  CHECK_INT64_ADD_OVERFLOW(node_def_size, func_def_lib_size,
      ErrorCode::DATA_OVERFLOW,
      "Overflow when calculate total bytes of node def[%ld] and function def"
      " lib[%ld]. op[%s]", node_def_size, func_def_lib_size,
      node.GetName().c_str())
  int64_t node_func_def_size = node_def_size + func_def_lib_size;
  AICPUE_LOGI("The nodeDef and funcDef size is [%lld], op type[%s]",
              node_func_def_size, node.GetType().c_str());

  // check overflow of the running total before adding the def sizes to it
  CHECK_INT64_ADD_OVERFLOW(workspace_size, node_func_def_size,
      ErrorCode::DATA_OVERFLOW,
      "Workspace overflow when add total bytes of kernel ran param[%ld], "
      "node def[%ld] and function def lib[%ld]. op[%s]",
      workspace_size, node_def_size, func_def_lib_size, node.GetName().c_str())
  workspace_size += node_func_def_size;
  return ge::SUCCESS;
}

/**
 * Fetches the serialized TF node def and function def library of a node.
 *
 * If the node def attribute (kTfNodeDef) is missing, it is created through
 * CreateNodeDef and read again. A missing function def attribute
 * (kTfFuncDef) is not an error: its size is reported as 0.
 *
 * @param node                the GE node to read from.
 * @param node_def_bytes      receives the node def bytes.
 * @param func_def_lib_bytes  receives the function def library bytes, if any.
 * @param node_def_size       receives the node def size in bytes.
 * @param func_def_lib_size   receives the function def library size in bytes,
 *                            or 0 when the attribute does not exist.
 * @return ge::SUCCESS on success, otherwise the error of the failing step.
 */
ge::Status TfKernelBuilder::ParseNodeDefAndFuncDef(const ge::Node &node,
                                                   ge::GeAttrValue::BYTES &node_def_bytes,
                                                   ge::GeAttrValue::BYTES &func_def_lib_bytes,
                                                   int64_t &node_def_size,
                                                   int64_t &func_def_lib_size) const {
  std::string node_name = node.GetName();
  ge::OpDescPtr op_desc = node.GetOpDesc();
  AICPU_CHECK_NOTNULL(op_desc)
  // calculate node def size
  if (!ge::AttrUtils::GetBytes(op_desc, kTfNodeDef, node_def_bytes)) {
    // The attribute is absent: build the node def from the IR, then re-read it.
    AICPUE_LOGI("Node def attr not exist in ge op[%s], op type[%s].",
                node_name.c_str(), node.GetType().c_str());
    AICPU_CHECK_RES_WITH_LOG(CreateNodeDef(node),
        "Call TfKernelBuilder::CreateNodeDef function failed, op[%s].",
        node.GetName().c_str())
    CHECK_RES_BOOL(ge::AttrUtils::GetBytes(op_desc, kTfNodeDef, node_def_bytes),
        ErrorCode::NODE_DEF_NOT_EXIST,
        AICPU_REPORT_CALL_ERROR(
            "Call ge::AttrUtils::GetBytes failed to get attr[%s], op[%s].",
            kTfNodeDef.c_str(), node_name.c_str()))
  }
  node_def_size = node_def_bytes.GetSize();

  // calculate function def size; a missing attribute yields size 0 and SUCCESS
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::GetBytes(op_desc, kTfFuncDef, func_def_lib_bytes),
                         AICPUE_LOGI("Function def attr is not exist in ge op[%s], op type[%s].",
                         node_name.c_str(), node.GetType().c_str());
                         func_def_lib_size = 0;
                         return ge::SUCCESS)
  func_def_lib_size = func_def_lib_bytes.GetSize();
  return ge::SUCCESS;
}

/**
 * Builds the TF node def for a GE node and stores it on the op description
 * as the kTfNodeDef attribute.
 *
 * Function ops are rejected. For framework ops the original type is read
 * from the kOriginalType attribute and written back as the op's type before
 * parsing. The node def is produced by the IR-to-TF parser registered for
 * the op type, serialized, and attached to the op description.
 *
 * @param node the GE node to create a node def for.
 * @return ge::SUCCESS on success, otherwise an error code describing the
 *         failing step.
 */
ge::Status TfKernelBuilder::CreateNodeDef(const ge::Node &node) const {
  std::string node_name = node.GetName();
  ge::OpDescPtr op_desc = node.GetOpDesc();
  // Validate the op description before the first dereference.
  AICPU_CHECK_NOTNULL(op_desc)
  std::string op_type = op_desc->GetType();
  // check function op
  if (op_type == kFunctionOp) {
    std::string err_msg = Stringcat("Can not create node def for function op[",
        node_name, "] in graph compile phase, op type[", op_type, "].");
    AICPU_REPORT_INNER_ERROR("%s.", err_msg.c_str());
    return ErrorCode::NODE_DEF_NOT_EXIST;
  }
  if (op_type == kFrameworkOp) {
    // A framework op carries its real type in the kOriginalType attribute.
    std::string original_type;
    CHECK_RES_BOOL(ge::AttrUtils::GetStr(op_desc, kOriginalType, original_type),
        ErrorCode::GET_ORIGINAL_TYPE_FAILED,
        AICPU_REPORT_CALL_ERROR(
            "Call ge::AttrUtils::GetStr failed to get attr[%s], op[%s].",
            kOriginalType.c_str(), node_name.c_str()))
    op_desc->SetType(original_type);
    op_type = original_type;
  }

  // IR -> tf
  NodeDef node_def;
  std::shared_ptr<Ir2tfBaseParser> parser = Ir2tfParserFactory::Instance().CreateIRParser(op_type);
  if (parser == nullptr) {
    AICPU_REPORT_INNER_ERROR("Create ir parser failed, op[%s], op type[%s].",
        node_name.c_str(), op_type.c_str());
    return ErrorCode::GET_IR2TF_PARSER_FAILED;
  }
  AICPU_CHECK_RES_WITH_LOG(parser->ParseNodeDef(node, &node_def),
      "Call ParseNodeDef function failed, op[%s], op type[%s].",
      node_name.c_str(), op_type.c_str())
  AICPUE_LOGI("Create node_def for ge op[%s] success, op type[%s].",
              node_name.c_str(), op_type.c_str());

  // set tf node_def for ge node
  std::string node_def_string;
  AICPU_CHECK_FALSE_EXEC(node_def.SerializeToString(&node_def_string),
    AICPU_REPORT_INNER_ERROR(
        "Serialize node def to string failed. op[%s], op type[%s].",
        node_name.c_str(), op_type.c_str());
    return ErrorCode::CREATE_NODEDEF_FAILED)

  const uint8_t *node_def_buff = reinterpret_cast<const uint8_t *>(node_def_string.data());
  AICPU_CHECK_FALSE_EXEC(
      ge::AttrUtils::SetZeroCopyBytes(op_desc, kTfNodeDef, ge::Buffer::CopyFrom(node_def_buff, node_def_string.length())),
      AICPU_REPORT_CALL_ERROR("Call ge::AttrUtils::SetZeroCopyBytes failed"
          " for [%s]. op[%s], op type[%s].",
          kTfNodeDef.c_str(), node_name.c_str(), op_type.c_str());
      return ErrorCode::CREATE_NODEDEF_FAILED)
  return ge::SUCCESS;
}

/**
 * Fills a KernelRunParam with one TensorDataInfo per input and per output of
 * the given op description.
 *
 * Inputs whose name is in the parser-provided ref-input set are converted as
 * reference inputs. Outputs are filled with is_output = true and honour
 * skip_dim_check.
 *
 * @param op_desc           op description supplying the tensor descriptions.
 * @param kernel_run_param  protobuf message that receives the entries.
 * @param skip_dim_check    forwarded to SetTensorDataInfo for outputs only.
 * @return ge::SUCCESS on success; ErrorCode::GET_IR2TF_PARSER_FAILED when no
 *         parser can be created for the op type; otherwise the state reported
 *         by SetTensorDataInfo for the first failing tensor.
 */
ge::Status TfKernelBuilder::BuildKernelRunParam(const ge::OpDesc &op_desc,
                                                FWKAdapter::KernelRunParam &kernel_run_param,
                                                bool skip_dim_check) const {
  // Construct input's content
  std::set<std::string> refinput_set;
  std::string op_type = op_desc.GetType();
  std::shared_ptr<Ir2tfBaseParser> parser = Ir2tfParserFactory::Instance().CreateIRParser(op_type);
  // The factory result is treated as nullable elsewhere in this file
  // (CreateNodeDef), so it must be checked before use here as well.
  if (parser == nullptr) {
    AICPU_REPORT_INNER_ERROR("Create ir parser failed, op[%s], op type[%s].",
        op_desc.GetName().c_str(), op_type.c_str());
    return ErrorCode::GET_IR2TF_PARSER_FAILED;
  }
  parser->GetRefInputSet(op_type, refinput_set);

  size_t input_size = op_desc.GetInputsSize();
  aicpu::State state;
  for (size_t i = 0; i < input_size; i++) {
    FWKAdapter::TensorDataInfo *input_tensor = kernel_run_param.add_input();
    ge::GeTensorDesc ge_tensor_desc = op_desc.GetInputDesc(i);
    std::string input_name = op_desc.GetInputNameByIndex(i);
    const bool is_ref = (refinput_set.count(input_name) != 0);
    AICPUE_LOGI("Op type[%s], input name[%s], is ref[%d]",
                op_type.c_str(), input_name.c_str(), is_ref);
    state = SetTensorDataInfo(ge_tensor_desc, input_tensor, is_ref);
    if (state.state != ge::SUCCESS) {
      state.msg = Stringcat(i, "th input's ", state.msg,
          ", op[", op_desc.GetName(), "].");
      AICPU_REPORT_INNER_ERROR("%s", state.msg.c_str());
      return state.state;
    }
  }

  // Construct output's content
  size_t output_size = op_desc.GetOutputsSize();
  for (size_t i = 0; i < output_size; i++) {
    FWKAdapter::TensorDataInfo *output_tensor = kernel_run_param.add_output();
    ge::GeTensorDesc ge_tensor_desc = op_desc.GetOutputDesc(i);
    state = SetTensorDataInfo(ge_tensor_desc, output_tensor, false, skip_dim_check, true);
    if (state.state != ge::SUCCESS) {
      state.msg = Stringcat(i, "th output's ",
          state.msg, ", op[", op_desc.GetName(), "].");
      AICPU_REPORT_INNER_ERROR("%s", state.msg.c_str());
      return state.state;
    }
  }
  return ge::SUCCESS;
}

/**
 * Fills one TensorDataInfo entry from a GE tensor description.
 *
 * The data type is converted to its TF counterpart, the data address is set
 * to a placeholder, and the dimensions are appended unless skip_dim_check is
 * set. For outputs with a non-empty shape range, the second value of each
 * range pair is used as the dimension; in every other case the dimensions
 * come from the tensor's shape.
 *
 * @param ge_tensor_desc    source tensor description.
 * @param tensor_data_info  entry to fill; dereferenced without a null check.
 * @param is_ref            forwarded to ConvertGeDataType2TfDataType.
 * @param skip_dim_check    when true, no dimension is validated or added.
 * @param is_output         when true, the shape range is consulted first.
 * @return a success state, or GE_SHAPE_SIZE_INVAILD with a message when a
 *         dimension is negative and is neither ge::UNKNOWN_DIM nor
 *         ge::UNKNOWN_DIM_NUM.
 */
aicpu::State TfKernelBuilder::SetTensorDataInfo(const ge::GeTensorDesc &ge_tensor_desc,
                                                FWKAdapter::TensorDataInfo *tensor_data_info,
                                                bool is_ref,
                                                bool skip_dim_check,
                                                bool is_output) const {
  const ge::DataType ge_dtype = ge_tensor_desc.GetDataType();
  tensor_data_info->set_dtype(static_cast<uint32_t>(ConvertGeDataType2TfDataType(ge_dtype, is_ref)));

  // Only the serialized length matters here, so a fake address is stored.
  // Different numbers encode to different lengths, hence the max uint64 value.
  tensor_data_info->set_data_addr(ULLONG_MAX);

  // The shape range is only queried for outputs; it stays empty otherwise.
  std::vector<std::pair<int64_t, int64_t>> shape_range;
  if (is_output) {
    AICPU_CHECK_RES_WITH_LOG(ge_tensor_desc.GetShapeRange(shape_range),
        "Call ge::GeTensorDesc::GetShapeRange function failed.")
  }

  std::vector<int64_t> dims;
  if (shape_range.empty()) {
    ge::GeShape ge_shape = ge_tensor_desc.GetShape();
    dims = ge_shape.GetDims();
  } else {
    AICPUE_LOGI("Get dims from shape range, dim size: [%zu].", shape_range.size());
    for (const auto &range : shape_range) {
      dims.emplace_back(range.second);
    }
  }

  if (skip_dim_check) {
    AICPUE_LOGI("Skip_dim_check for unknown shape");
    return aicpu::State(ge::SUCCESS);
  }

  const uint32_t dim_count = dims.size();
  for (uint32_t i = 0; i < dim_count; i++) {
    const int64_t dim = dims[i];
    const bool is_unknown_marker = (dim == ge::UNKNOWN_DIM) || (dim == ge::UNKNOWN_DIM_NUM);
    if ((dim < 0) && !is_unknown_marker) {
      std::string err_msg =  Stringcat("dim[", i,
          "] is invalid, shape is [", DebugString(dims), "].");
      return aicpu::State(GE_SHAPE_SIZE_INVAILD, err_msg);
    }
    tensor_data_info->add_dim(dim);
  }
  return aicpu::State(ge::SUCCESS);
}

/**
 * Generates the task definition(s) for a TF kernel node.
 *
 * The whole sequence runs under mutex_ because it communicates with
 * GetTaskInfoCallback through file-level globals (g_op_index, g_task_def,
 * g_task_info). After the kernel is built and launched, g_task_def is
 * appended to `tasks`. For ops that carry the kAttrNameUnknownShape
 * attribute and whose unknown-shape type is ge::DEPEND_COMPUTE, an extra
 * mem-copy task is appended.
 *
 * @param node         the GE node to generate tasks for.
 * @param run_context  run context; its model and stream must be non-null.
 * @param tasks        receives the generated task definitions.
 * @return ge::SUCCESS on success, otherwise the error of the failing step.
 */
ge::Status TfKernelBuilder::GenerateTask(const ge::Node &node, const ge::RunContext &run_context, std::vector<domi::TaskDef> &tasks) {
  AICPUE_LOGI("TFKernel's op[%s], op type[%s] run GenerateTask. ", node.GetName().c_str(), node.GetType().c_str());
  // Check the input data
  std::shared_ptr<ge::OpDesc> op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL_ERRCODE(op_desc_ptr, ErrorCode::INPUT_PARAM_NULL)
  AICPU_CHECK_NOTNULL_ERRCODE(run_context.model, ErrorCode::INPUT_PARAM_NULL)
  AICPU_CHECK_NOTNULL_ERRCODE(run_context.stream, ErrorCode::INPUT_PARAM_NULL)
  std::lock_guard<std::mutex> lock(mutex_);
  g_op_index = op_desc_ptr->GetId();

  // Register the task generation callback. A failure is only logged as an
  // event and does not abort task generation.
  rtError_t rt_ret = rtSetTaskGenCallback(GetTaskInfoCallback);
  AICPU_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE,
      AICPUE_LOGEVENT(
          "Call rtSetTaskGenCallback function failed, ret[0x%X]", rt_ret))
  AICPU_CHECK_RES_WITH_LOG(BuildAndLaunchKernel(node, run_context),
      "Call TfKernelBuilder::BuildAndLaunchKernel function failed, op[%s].",
      node.GetName().c_str())
  // NOTE(review): this relies on GetTaskInfoCallback having filled g_task_def
  // by now, presumably during the launch above; confirm against the runtime.
  tasks.emplace_back(g_task_def);

  if (!ge::AttrUtils::HasAttr(op_desc_ptr, kAttrNameUnknownShape)) {
    // Known shape: the single kernel task is all that is needed.
    return ge::SUCCESS;
  }

  int32_t shape_type = 0;
  CHECK_RES_BOOL(ge::AttrUtils::GetInt(op_desc_ptr, ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, shape_type),
      INVOKE_GRAPH_ITF_FAILED,
      AICPU_REPORT_CALL_ERROR(
          "Call ge::AttrUtils::GetInt failed to get attr[%s], op[%s].",
          ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE.c_str(), node.GetName().c_str()))
  // Unknown shape whose output depends on compute: add a mem-copy task.
  if (shape_type == ge::DEPEND_COMPUTE) {
    STR_FWK_OP_KERNEL task = {0};
    std::string mem_copy_task_info;
    uint64_t data_info_size = static_cast<uint64_t>(op_desc_ptr->GetOutputsSize()) * kNum;
    // NOTE(review): any result reported by GenMemCopyTask is not inspected
    // here; confirm whether it can fail.
    GenMemCopyTask(data_info_size, task, mem_copy_task_info);

    domi::TaskDef task_def;
    task_def.set_type(RT_MODEL_TASK_KERNEL_EX);
    domi::KernelExDef *kernel_def_ex = task_def.mutable_kernel_ex();
    kernel_def_ex->set_args(reinterpret_cast<void *>(&task), sizeof(STR_FWK_OP_KERNEL));
    kernel_def_ex->set_args_size(sizeof(STR_FWK_OP_KERNEL));
    kernel_def_ex->set_task_info(mem_copy_task_info);
    kernel_def_ex->set_task_info_size(mem_copy_task_info.size());
    kernel_def_ex->set_op_index(g_op_index);
    tasks.emplace_back(task_def);
  }
  return ge::SUCCESS;
}

/**
 * Task generation callback registered through rtSetTaskGenCallback.
 *
 * Converts the runtime's task info into g_task_def, using the bytes of
 * g_str_fwkop_kernel_ptr as the task args and g_task_info as the task info
 * payload. g_task_info is cleared after it has been copied.
 *
 * @param model      unused.
 * @param task_info  task info supplied by the runtime; must be non-null and
 *                   of type RT_MODEL_TASK_KERNEL_EX.
 * @return RT_ERROR_NONE on success, ACL_ERROR_RT_PARAM_INVALID when the task
 *         info is null or of the wrong type, or when no kernel struct is
 *         available to copy the args from.
 */
rtError_t TfKernelBuilder::GetTaskInfoCallback(rtModel_t model, rtTaskInfo_t *task_info) {
  (void)model;
  AICPUE_LOGI("Invoke the runtime successfully, begin to get the task_info in the callback.");
  // Verify the task info
  if (task_info == nullptr) {
    AICPU_REPORT_INNER_ERROR("The task_info from AICPU is null.");
    return ACL_ERROR_RT_PARAM_INVALID;
  }

  if (task_info->type != RT_MODEL_TASK_KERNEL_EX) {
    AICPU_REPORT_INNER_ERROR("The task type[%u] is not RT_MODEL_TASK_KERNEL_EX[%d].",
        task_info->type, RT_MODEL_TASK_KERNEL_EX);
    return ACL_ERROR_RT_PARAM_INVALID;
  }

  // The args are copied from the kernel struct created by
  // BuildAndLaunchKernel; copying a non-zero length from a null pointer would
  // be undefined behaviour, so reject that state explicitly.
  if (g_str_fwkop_kernel_ptr == nullptr) {
    AICPU_REPORT_INNER_ERROR("The STR_FWK_OP_KERNEL object is null in the task info callback.");
    return ACL_ERROR_RT_PARAM_INVALID;
  }

  // Convert the task_info to g_task_def
  g_task_def.set_stream_id(task_info->streamID);
  g_task_def.set_type(task_info->type);
  domi::KernelExDef *kernel_def_ex = g_task_def.mutable_kernel_ex();
  kernel_def_ex->set_args(reinterpret_cast<void *>(g_str_fwkop_kernel_ptr.get()), task_info->u.kernelTaskEx.argsSize);
  kernel_def_ex->set_args_size(task_info->u.kernelTaskEx.argsSize);
  kernel_def_ex->set_task_info(g_task_info);
  kernel_def_ex->set_task_info_size(g_task_info.size());
  kernel_def_ex->set_op_index(g_op_index);
  kernel_def_ex->set_flags(task_info->u.kernelTaskEx.flags);

  // The payload has been copied into the task definition; reset the buffer.
  g_task_info.clear();
  return RT_ERROR_NONE;
}

/**
 * Builds the framework kernel struct and task payload for a node, then
 * launches it through the runtime.
 *
 * Blocks of memory involved:
 *  1. STR_FWK_OP_KERNEL: the struct in which the other blocks' offsets and
 *     lengths are recorded;
 *  2. InputOutputBuf: a serialized protobuf KernelRunParam describing the
 *     input and output buffers;
 *  3. NodeDefBuf: the serialized TF NodeDef, taken from GE's graph;
 *  4. FuncDef: the serialized function def library for the fused graph.
 *
 * Blocks 2-4 are concatenated into the file-level g_task_info in that order;
 * the buffer offsets stored in the struct are relative to its start. For ops
 * without the kAttrNameUnknownShape attribute, the op's workspace must be at
 * least as large as the three blocks together.
 *
 * g_task_info is reset at the start so that bytes left behind by an earlier
 * call that failed after appending (the buffer is otherwise only cleared by
 * GetTaskInfoCallback) cannot leak into this op's payload.
 *
 * @param node         the GE node to build and launch.
 * @param run_context  run context supplying session id, data memory base and
 *                     stream.
 * @return ge::SUCCESS on success, otherwise the error of the failing step.
 */
ge::Status TfKernelBuilder::BuildAndLaunchKernel(const ge::Node &node, const ge::RunContext &run_context) const {
  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL(op_desc_ptr)
  // Start from an empty payload; see the function comment.
  g_task_info.clear();

  // Step1 : define the task api struct.
  AICPU_MAKE_SHARED(g_str_fwkop_kernel_ptr = std::make_shared<STR_FWK_OP_KERNEL>(),
      AICPU_REPORT_INNER_ERROR("Create STR_FWK_OP_KERNEL object failed, op[%s]", node.GetName().c_str());
      return ErrorCode::MEMORY_ALLOC_FAILED)
  // Type 0 represent tensorflow
  g_str_fwkop_kernel_ptr->fwkKernelType = FMK_KERNEL_TYPE_TF;
  FWKAdapter::FWKOperateParam *str_tf_kernel = &(g_str_fwkop_kernel_ptr->fwkKernelBase.fwk_kernel);
  str_tf_kernel->opType = FWKAdapter::FWK_ADPT_KERNEL_RUN;
  str_tf_kernel->sessionID = run_context.sessionId;
  str_tf_kernel->stepIDAddr = 0;
  str_tf_kernel->kernelID = GenerateUniqueId();
  str_tf_kernel->extInfoLen = 0;
  str_tf_kernel->extInfoAddr = 0;
  AICPUE_LOGI("Op type[%s] The kernel id is [%llu], session id is [%llu]",
              node.GetType().c_str(), str_tf_kernel->kernelID, str_tf_kernel->sessionID);

  // Step2 : Build the StrFWKKernel
  FWKAdapter::KernelRunParam kernel_run_param;
  AICPU_CHECK_RES(BuildKernelRunParam(*op_desc_ptr, kernel_run_param))
  str_tf_kernel->inputOutputLen = static_cast<int64_t>(kernel_run_param.ByteSizeLong());
  str_tf_kernel->inputOutputBuf = 0;
  AICPUE_LOGI("The kernel_run_param_size size is [%lld], op type[%s]",
              str_tf_kernel->inputOutputLen, node.GetType().c_str());
  // Serialize the kernel_run_param
  std::string kernel_run_param_str;
  CHECK_RES_BOOL(kernel_run_param.SerializeToString(&kernel_run_param_str),
      ErrorCode::SERIALIZE_KERNEL_RUN_PARAM_FAILED,
      AICPU_REPORT_INNER_ERROR("Serialize kernel run param to string failed, op[%s]",
          node.GetName().c_str()))
  g_task_info.append(kernel_run_param_str);

  // Step3~4 : build the tf's node_def and funcDef
  ge::GeAttrValue::BYTES node_def_bytes;
  ge::GeAttrValue::BYTES func_def_lib_bytes;
  int64_t node_def_size = 0;
  int64_t func_def_lib_size = 0;
  AICPU_CHECK_RES_WITH_LOG(ParseNodeDefAndFuncDef(node, node_def_bytes, func_def_lib_bytes, node_def_size, func_def_lib_size),
      "Call TfKernelBuilder::ParseNodeDefAndFuncDef function failed, op[%s].",
      node.GetName().c_str())
  str_tf_kernel->funDefLibLen = 0; // initial value
  str_tf_kernel->nodeDefLen = node_def_size;
  // The node def follows the input/output block in the payload.
  str_tf_kernel->nodeDefBuf = str_tf_kernel->inputOutputBuf + str_tf_kernel->inputOutputLen;
  if (node_def_bytes.GetData() == nullptr) {
    AICPU_REPORT_INNER_ERROR(
        "Append node def to g_task_info failed data, node def is null, op[%s].",
        node.GetName().c_str());
    return INPUT_PARAM_NULL;
  }
  g_task_info.append(reinterpret_cast<const char *>(node_def_bytes.GetData()), node_def_size);

  // Serialize the funcDef; it is optional and follows the node def.
  if (func_def_lib_size > 0 && func_def_lib_bytes.GetData() != nullptr) {
    str_tf_kernel->funDefLibLen = func_def_lib_size;
    str_tf_kernel->funDefLibBuf = str_tf_kernel->nodeDefBuf + str_tf_kernel->nodeDefLen;
    g_task_info.append(reinterpret_cast<const char *>(func_def_lib_bytes.GetData()), func_def_lib_size);
  }

  // Init inputOutputAddr and workspaceBaseAddr, GE will refresh this value
  str_tf_kernel->inputOutputAddr = 0;
  str_tf_kernel->workspaceBaseAddr = 0;

  // Update the FmkOp info
  AICPU_CHECK_RES(UpdateFmkOpInfo(op_desc_ptr))

  CHECK_UINT64_ADD_OVERFLOW(str_tf_kernel->inputOutputLen,
      str_tf_kernel->nodeDefLen,
      ErrorCode::DATA_OVERFLOW,
      "Overflow occurred when calculate total bytes of input/output info[%lu] and"
      " node def[%lu]. Calculate workspace total bytes failed, op[%s]",
      str_tf_kernel->inputOutputLen, str_tf_kernel->nodeDefLen,
      node.GetName().c_str())
  CHECK_UINT64_ADD_OVERFLOW(str_tf_kernel->inputOutputLen + str_tf_kernel->nodeDefLen,
      str_tf_kernel->funDefLibLen,
      ErrorCode::DATA_OVERFLOW,
      "Overflow occurred when calculate total bytes of input/output info[%lu], node "
      "def[%lu] and function def[%lu]. Calculate workspace total bytes failed, op[%s]",
      str_tf_kernel->inputOutputLen, str_tf_kernel->nodeDefLen,
      str_tf_kernel->funDefLibLen, node.GetName().c_str())

  // The presence of kAttrNameUnknownShape marks an unknown-shape op.
  const bool is_unknown_shape = ge::AttrUtils::HasAttr(op_desc_ptr, kAttrNameUnknownShape);
  if (!is_unknown_shape) {
    // For known-shape ops, verify the workspace can hold the payload.
    uint64_t workspace_bytes_size = 0;
    AICPU_CHECK_RES_WITH_LOG(
        GetWorkspaceInfo(op_desc_ptr, run_context.dataMemBase,
            workspace_bytes_size),
        "Call KernelBuilder::GetWorkspaceInfo function failed, op[%s].",
            node.GetName().c_str())

    uint64_t min_memory = str_tf_kernel->inputOutputLen +
                         str_tf_kernel->nodeDefLen + str_tf_kernel->funDefLibLen;
    if (workspace_bytes_size < min_memory) {
      AICPU_REPORT_INNER_ERROR(
          "Workspace memory not enough, given[%lu] bytes, expected[%lu] bytes, op[%s]",
          workspace_bytes_size, min_memory, node.GetName().c_str());
      return GE_MEM_NOT_ENOUGH;
    }
  }

  // make and set extend info
  std::vector<char> task_ext_info;
  domi::KernelExDef *kernel_def_ex = g_task_def.mutable_kernel_ex();
  AICPU_CHECK_RES_WITH_LOG(MakeTaskExtInfo(node, task_ext_info),
      "Call TfKernelBuilder::MakeTaskExtInfo function failed, op[%s].",
          node.GetName().c_str())
  if (task_ext_info.size() == 0) {
    str_tf_kernel->extInfoLen = 0;
    kernel_def_ex->clear_kernel_ext_info();
    kernel_def_ex->set_kernel_ext_info_size(0);
  } else {
    str_tf_kernel->extInfoLen = task_ext_info.size();
    kernel_def_ex->set_kernel_ext_info(reinterpret_cast<void *>(task_ext_info.data()), task_ext_info.size());
    kernel_def_ex->set_kernel_ext_info_size(task_ext_info.size());
  }
  AICPUE_LOGI("Node info: unknown shape is [%d], extend info length[%lu], op[%s], op type[%s].",
              is_unknown_shape, str_tf_kernel->extInfoLen, op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str());
  // Dispatch the task, then get the task_info.
  rtError_t rt_res = rtKernelLaunchEx(
      reinterpret_cast<void *>(g_str_fwkop_kernel_ptr.get()),
      static_cast<uint64_t>(sizeof(STR_FWK_OP_KERNEL)), 0, run_context.stream);
  AICPU_IF_BOOL_EXEC(rt_res != RT_ERROR_NONE,
      AICPU_REPORT_CALL_ERROR(
          "Call rtKernelLaunchEx function failed, op[%s]", node.GetName().c_str());
      return CALL_RT_API_FAILED)

  AICPUE_LOGI("BuildAndLaunchKernel success, kernel id[%llu], op[%s], op type[%s]",
              str_tf_kernel->kernelID, node.GetName().c_str(), node.GetType().c_str());
  return ge::SUCCESS;
}

/**
 * Rewrites an op description so that it is treated as a framework op.
 *
 * The current op type is saved in the kOriginalType attribute, the type is
 * changed to kFrameworkOp, the kFrameworkType attribute is set to 3 and the
 * imply type attribute is set to AI_CPU.
 *
 * @param op_desc_ptr op description to update; must not be null.
 * @return ge::SUCCESS on success, ErrorCode::ADD_ATTR_FAILED when one of the
 *         attributes cannot be set.
 */
ge::Status TfKernelBuilder::UpdateFmkOpInfo(std::shared_ptr<ge::OpDesc> &op_desc_ptr) const {
  AICPU_CHECK_NOTNULL(op_desc_ptr)
  // Remember the current type before it is overwritten below.
  std::string original_type = op_desc_ptr->GetType();
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::SetStr(op_desc_ptr, kOriginalType, original_type),
      AICPU_REPORT_CALL_ERROR("Call ge::AttrUtils::SetStr failed to set attr[%s], op[%s].",
          kOriginalType.c_str(), op_desc_ptr->GetName().c_str());
      return ErrorCode::ADD_ATTR_FAILED)
  op_desc_ptr->SetType(kFrameworkOp);

  // value 3 represent the framework tensorflow
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::SetInt(op_desc_ptr, kFrameworkType, 3),
      AICPU_REPORT_CALL_ERROR(
          "Call ge::AttrUtils::SetInt failed to set attr[%s], op[%s].",
          kFrameworkType.c_str(), op_desc_ptr->GetName().c_str());
      return ErrorCode::ADD_ATTR_FAILED)
  // Mark the op as implemented by AI CPU.
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::SetInt(op_desc_ptr, ge::ATTR_NAME_IMPLY_TYPE, static_cast<int64_t>(domi::ImplyType::AI_CPU)),
      AICPU_REPORT_CALL_ERROR(
          "Call ge::AttrUtils::SetInt failed to set attr[%s], op[%s].",
          ge::ATTR_NAME_IMPLY_TYPE.c_str(), op_desc_ptr->GetName().c_str());
      return ErrorCode::ADD_ATTR_FAILED)
  return ge::SUCCESS;
}

/**
 * Assembles the task extend info buffer for a node.
 *
 * The op name record is written first, followed by the common base record.
 * When NeedUpdateAddr reports that addresses must be updated, a
 * FWK_ADPT_EXT_UPDATE_ADDR record (header plus one int32 flag) is appended.
 *
 * @param node           the GE node to describe.
 * @param task_ext_info  buffer that receives the extend info records.
 * @return ge::SUCCESS on success, otherwise the status of the failing step.
 */
ge::Status TfKernelBuilder::MakeTaskExtInfo(const ge::Node &node,
                                            std::vector<char> &task_ext_info) const {
  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL_ERRCODE(op_desc_ptr, ErrorCode::INPUT_PARAM_NULL)

  // WARNING: OP NAME MUST BE THE FIRST EXTEND INFO FOR RUNTIME!!!
  const ge::Status op_name_status = MakeExtInfoForOpName(op_desc_ptr, task_ext_info);
  if (op_name_status != ge::SUCCESS) {
    AICPU_REPORT_INNER_ERROR("Call MakeExtInfoForOpName failed, op[%s].",
        op_desc_ptr->GetName().c_str());
    return op_name_status;
  }

  // Common base extend info comes second.
  const ge::Status base_status = MakeBaseExtInfo(op_desc_ptr, task_ext_info);
  if (base_status != ge::SUCCESS) {
    AICPU_REPORT_INNER_ERROR("Call MakeBaseExtInfo failed, op[%s], op type[%s].",
        op_desc_ptr->GetName().c_str(), op_desc_ptr->GetType().c_str());
    return base_status;
  }

  int32_t update_addr_flag = 0;
  if (!NeedUpdateAddr(node, update_addr_flag)) {
    // No address update record is required for this node.
    return ge::SUCCESS;
  }

  AICPUE_LOGI("Set update inputs or outputs address extend info, update_addr_flag[%d], op[%s], op type[%s].",
              update_addr_flag, node.GetName().c_str(), op_desc_ptr->GetType().c_str());

  // Grow the buffer by one record: header followed by an int32 value.
  const uint64_t previous_len = task_ext_info.size();
  const uint64_t record_len = FWKAdapter::kExtInfoHeadSize + sizeof(int32_t);
  task_ext_info.resize(previous_len + record_len, 0);

  // Fill in the record header at the start of the newly added region.
  char *const record_begin = task_ext_info.data() + previous_len;
  FWKAdapter::ExtInfo *header = reinterpret_cast<FWKAdapter::ExtInfo *>(record_begin);
  header->infoType = FWKAdapter::FWK_ADPT_EXT_UPDATE_ADDR;
  header->infoLen = sizeof(int32_t);

  // The flag value is stored directly after the header.
  char *const value_begin = record_begin + FWKAdapter::kExtInfoHeadSize;
  *reinterpret_cast<int32_t *>(value_begin) = update_addr_flag;
  return ge::SUCCESS;
}

/**
 * Decides whether a node's input and/or output addresses need updating.
 *
 * A known-shape node in a dynamic shape graph always needs both updated.
 * Otherwise the decision is based on the ATTR_NAME_NODE_CONNECT_INPUT and
 * ATTR_NAME_NODE_CONNECT_OUTPUT attributes (mini inference, zero copy).
 *
 * @param node              the GE node to inspect.
 * @param update_addr_flag  receives the update kind; left untouched when the
 *                          function returns false.
 * @return true when an address update is needed, false otherwise.
 */
bool TfKernelBuilder::NeedUpdateAddr(const ge::Node &node, int32_t &update_addr_flag) const {
  if (IsKnownNodeDynamic(node)) {
    // Known node in dynamic shape graph: update both inputs and outputs.
    update_addr_flag = FWKAdapter::FWK_ADPT_UPDATE_INPUT_OUTPUT;
    AICPUE_LOGI("The known shape node in dynamic shape graph, op[%s], op type[%s].",
                node.GetName().c_str(), node.GetType().c_str());
    return true;
  }

  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL_ERRCODE(op_desc_ptr, false)

  // Mini inference (zero copy): is the node a first and/or a last node?
  const bool is_first_node = ge::AttrUtils::HasAttr(op_desc_ptr, ge::ATTR_NAME_NODE_CONNECT_INPUT);
  const bool is_last_node = ge::AttrUtils::HasAttr(op_desc_ptr, ge::ATTR_NAME_NODE_CONNECT_OUTPUT);

  if (is_first_node && is_last_node) {
    update_addr_flag = FWKAdapter::FWK_ADPT_UPDATE_INPUT_OUTPUT;
    return true;
  }
  if (is_first_node) {
    update_addr_flag = FWKAdapter::FWK_ADPT_UPDATE_INPUT;
    return true;
  }
  if (is_last_node) {
    update_addr_flag = FWKAdapter::FWK_ADPT_UPDATE_OUTPUT;
    return true;
  }
  return false;
}

/**
 * Check whether the node is a known shape node that lives in a dynamic shape graph.
 *
 * @param node node to inspect
 * @return the value of ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED read from the root graph;
 *         false when the node has no op desc, carries the unknown shape attr, or
 *         its owner graph / root graph cannot be found
 */
bool TfKernelBuilder::IsKnownNodeDynamic(const ge::Node &node) const {
  ge::OpDescPtr op_desc_ptr = node.GetOpDesc();
  AICPU_CHECK_NOTNULL_ERRCODE(op_desc_ptr, false)

  // A node marked with the unknown shape attr is not a known shape node.
  const bool is_unknown_shape = ge::AttrUtils::HasAttr(op_desc_ptr, kAttrNameUnknownShape);
  if (is_unknown_shape) {
    return false;
  }

  auto owner_compute_graph = node.GetOwnerComputeGraph();
  if (owner_compute_graph == nullptr) {
    AICPUE_LOGW("Get null owner compute graph, op[%s], op type[%s].",
                node.GetName().c_str(), node.GetType().c_str());
    return false;
  }

  auto root_compute_graph = ge::GraphUtils::FindRootGraph(owner_compute_graph);
  if (root_compute_graph == nullptr) {
    AICPUE_LOGW("Get null root graph, op[%s], op type[%s].",
                node.GetName().c_str(), node.GetType().c_str());
    return false;
  }

  // The result of GetBool is ignored on purpose: the default (false) is kept on failure.
  bool dynamic_shape_partitioned = false;
  (void)ge::AttrUtils::GetBool(root_compute_graph, ge::ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED,
                               dynamic_shape_partitioned);
  return dynamic_shape_partitioned;
}

/**
 * Fill the framework kernel struct of a single op run task and append the
 * serialized payload of the node to task_info.
 *
 * Data is appended to task_info in this order: the serialized KernelRunParam, the
 * node def bytes and, when present, the function def library bytes. The buf fields
 * of str_tf_kernel are set to the running byte counts of the data appended here
 * (inputOutputBuf = 0, nodeDefBuf = inputOutputLen, funDefLibBuf = nodeDefBuf + nodeDefLen).
 *
 * @param node node the task is generated for, must not be null
 * @param str_tf_kernel kernel struct to fill, must not be null
 * @param task_info buffer the serialized data is appended to
 * @param skip_dim_check forwarded to BuildKernelRunParam
 * @return ge::SUCCESS on success; INPUT_PARAM_NULL when node, str_tf_kernel, the op
 *         desc or the node def data is null; otherwise the error code produced by
 *         the failing check
 */
ge::Status TfKernelBuilder::GenTaskImply(const ge::NodePtr &node,
                                         FWKAdapter::FWKOperateParam *str_tf_kernel,
                                         std::string &task_info,
                                         bool skip_dim_check) {
  // Reject null inputs before anything is dereferenced.
  if (node == nullptr) {
    AICPU_REPORT_INNER_ERROR("Generate task failed, [%s] is null.", "node");
    return INPUT_PARAM_NULL;
  }
  if (str_tf_kernel == nullptr) {
    AICPU_REPORT_INNER_ERROR("Generate task failed, kernel struct is null, op[%s].",
        node->GetName().c_str());
    return INPUT_PARAM_NULL;
  }
  ge::OpDescPtr op_desc_ptr = node->GetOpDesc();
  if (op_desc_ptr == nullptr) {
    AICPU_REPORT_INNER_ERROR("Generate task failed, op desc is null, op[%s].",
        node->GetName().c_str());
    return INPUT_PARAM_NULL;
  }

  // Set default value, these fields may not be used
  str_tf_kernel->sessionID = 0;
  str_tf_kernel->kernelID = 0;

  // for single op run
  str_tf_kernel->opType = FWKAdapter::FWK_ADPT_SINGLE_OP_RUN;

  // Build the KernelRunParam
  FWKAdapter::KernelRunParam kernel_run_param;
  AICPU_CHECK_RES_WITH_LOG(BuildKernelRunParam(*op_desc_ptr, kernel_run_param, skip_dim_check),
      "Call TfKernelBuilder::BuildKernelRunParam function failed, op[%s]",
          node->GetName().c_str())
  str_tf_kernel->inputOutputLen = static_cast<int64_t>(kernel_run_param.ByteSizeLong());
  str_tf_kernel->inputOutputBuf = 0;
  AICPUE_LOGI("The kernel_run_param_size size is [%lld], op type[%s]",
              str_tf_kernel->inputOutputLen, node->GetType().c_str());
  // Serialize the kernel_run_param
  std::string kernel_run_param_str;
  CHECK_RES_BOOL(kernel_run_param.SerializeToString(&kernel_run_param_str),
      ErrorCode::SERIALIZE_KERNEL_RUN_PARAM_FAILED,
      AICPU_REPORT_INNER_ERROR("Serialize kernel run param to string failed. op[%s].",
          node->GetName().c_str()))
  task_info.append(kernel_run_param_str);

  // Build the tf's nodeDef and funcDef
  ge::GeAttrValue::BYTES node_def_bytes;
  ge::GeAttrValue::BYTES func_def_lib_bytes;
  int64_t node_def_size = 0;
  int64_t func_def_lib_size = 0;
  // Serialize the nodeDef
  AICPU_CHECK_RES_WITH_LOG(ParseNodeDefAndFuncDef(*node, node_def_bytes, func_def_lib_bytes, node_def_size, func_def_lib_size),
      "Call TfKernelBuilder::ParseNodeDefAndFuncDef function failed, op[%s].",
          node->GetName().c_str())
  str_tf_kernel->funDefLibLen = 0; // initial value
  str_tf_kernel->nodeDefLen = static_cast<uint64_t>(node_def_size);
  // The node def is placed right behind the input/output data.
  str_tf_kernel->nodeDefBuf = str_tf_kernel->inputOutputBuf + str_tf_kernel->inputOutputLen;
  if (node_def_bytes.GetData() == nullptr) {
    AICPU_REPORT_INNER_ERROR(
        "Append node def to task_info failed, node def is null, op[%s].",
        node->GetName().c_str());
    return INPUT_PARAM_NULL;
  }
  task_info.append(reinterpret_cast<const char *>(node_def_bytes.GetData()), node_def_size);
  // Serialize the funcDef; it is optional and skipped when it is empty or has no data.
  if (func_def_lib_size > 0 && func_def_lib_bytes.GetData() != nullptr) {
    str_tf_kernel->funDefLibLen = func_def_lib_size;
    // The function def library is placed right behind the node def.
    str_tf_kernel->funDefLibBuf = str_tf_kernel->nodeDefBuf + str_tf_kernel->nodeDefLen;
    task_info.append(reinterpret_cast<const char *>(func_def_lib_bytes.GetData()), func_def_lib_size);
  }

  // Update the FmkOp info
  AICPU_CHECK_RES(UpdateFmkOpInfo(op_desc_ptr))
  CHECK_UINT64_ADD_OVERFLOW(str_tf_kernel->inputOutputLen,
      str_tf_kernel->nodeDefLen,
      ErrorCode::DATA_OVERFLOW,
      "Overflow occurred when calculate total bytes of input/output info[%lu] and"
      " node def[%lu]. Calculate workspace total bytes failed, op[%s]",
      str_tf_kernel->inputOutputLen, str_tf_kernel->nodeDefLen,
      node->GetName().c_str())
  CHECK_UINT64_ADD_OVERFLOW(str_tf_kernel->inputOutputLen + str_tf_kernel->nodeDefLen,
      str_tf_kernel->funDefLibLen,
      ErrorCode::DATA_OVERFLOW,
      "Overflow occurred when calculate total bytes of input/output info[%lu], node "
      "def[%lu] and function def[%lu]. Calculate workspace total bytes failed, op[%s]",
      str_tf_kernel->inputOutputLen, str_tf_kernel->nodeDefLen,
      str_tf_kernel->funDefLibLen, node->GetName().c_str())

  // disable ext info
  str_tf_kernel->extInfoLen = 0;
  str_tf_kernel->extInfoAddr = 0;
  return ge::SUCCESS;
}

/**
 * Generate the single op run task of the node for the TF kernel.
 *
 * @param node node the task is generated for
 * @param task framework kernel struct to fill, its type is set to FMK_KERNEL_TYPE_TF
 * @param task_info buffer the serialized task data is appended to
 * @return ge::SUCCESS when GenTaskImply succeeds
 */
ge::Status TfKernelBuilder::GenSingleOpRunTask(const ge::NodePtr &node, STR_FWK_OP_KERNEL &task, std::string &task_info) {
  AICPUE_LOGI("Op[%s], op type[%s] start GenSingleOpRunTask", node->GetName().c_str(), node->GetType().c_str());
  task.fwkKernelType = FMK_KERNEL_TYPE_TF;

  // This path always asks GenTaskImply to skip the dim check.
  constexpr bool kSkipDimCheck = true;
  AICPU_CHECK_RES_WITH_LOG(GenTaskImply(node, &(task.fwkKernelBase.fwk_kernel), task_info, kSkipDimCheck),
        "Call TfKernelBuilder::GenTaskImply function failed, op[%s].",
        node->GetName().c_str())
  return ge::SUCCESS;
}

/**
 * Generate the task of a temporary "MemCopy" node for the TF kernel.
 *
 * A GE node of type MemCopy with four inputs, no output, format NCHW, data type
 * DT_UINT64 and shape {data_info_size} is created, its attr "num" is set to
 * data_info_size and the task is then built by GenTaskImply.
 *
 * @param data_info_size used as the only dim of the node shape and as attr "num"
 * @param task framework kernel struct to fill, its type is set to FMK_KERNEL_TYPE_TF
 * @param task_info buffer the serialized task data is appended to
 * @return ge::SUCCESS on success, ErrorCode::ADD_ATTR_FAILED when attr "num" cannot
 *         be set, otherwise the error produced by the failing check
 */
ge::Status TfKernelBuilder::GenMemCopyTask(uint64_t data_info_size,
                                           STR_FWK_OP_KERNEL &task,
                                           std::string &task_info) {
  task.fwkKernelType = FMK_KERNEL_TYPE_TF;

  // NOTE(review): copy_count is never incremented in this function, so the node
  // name is always "MemCopy_0" - confirm whether unique names were intended.
  static int copy_count = 0;
  std::string node_type("MemCopy");
  std::string node_name(node_type + "_" + std::to_string(copy_count));

  // memCopy has four inputs and zero outputs
  int in_count = 4;
  int out_count = 0;
  ge::Format format = ge::FORMAT_NCHW;
  // DT_UINT64 is the data type of the element in the struct DataPtrInfo
  ge::DataType data_type = ge::DT_UINT64;
  // one-dim shape holding data_info_size
  std::vector<int64_t> shape(1, static_cast<int64_t>(data_info_size));

  // Build Ge node
  ge::NodePtr node = aicpu::GenGeNode(node_name, node_type, in_count, out_count, format, data_type, shape);
  AICPU_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  AICPU_CHECK_FALSE_EXEC(ge::AttrUtils::SetInt(op_desc, "num", data_info_size),
      AICPU_REPORT_CALL_ERROR(
          "Call ge::AttrUtils::SetInt failed to set attr[num], op[%s].",
          node_name.c_str());
      return ErrorCode::ADD_ATTR_FAILED)
  AICPUE_LOGI("Op[%s], op type[%s] start GenMemCopyTask", node->GetName().c_str(), node->GetType().c_str());

  // Build the str_tf_kernel
  AICPU_CHECK_RES_WITH_LOG(GenTaskImply(node, &(task.fwkKernelBase.fwk_kernel), task_info),
        "Call TfKernelBuilder::GenTaskImply function failed, op[%s].",
        node->GetName().c_str())
  return ge::SUCCESS;
}

/**
 * Collect the tf data types of all inputs and outputs of the op.
 *
 * @param op_desc_ptr op desc whose input/output descs are read
 * @param inputs_type one converted data type is appended per input; an input whose
 *        name is in the ref input set of the op type is converted with is_ref = true
 * @param outputs_type one converted data type is appended per output, always
 *        converted with is_ref = false
 */
void TfKernelBuilder::GetInOutPutsDataType(const ge::OpDescPtr &op_desc_ptr,
                                           std::vector<uint32_t> &inputs_type,
                                           std::vector<uint32_t> &outputs_type) const {
  // Ask the IR parser of this op type which inputs are ref inputs.
  std::string op_type = op_desc_ptr->GetType();
  std::set<std::string> ref_input_names;
  auto ir_parser = Ir2tfParserFactory::Instance().CreateIRParser(op_type);
  ir_parser->GetRefInputSet(op_type, ref_input_names);

  const size_t input_count = op_desc_ptr->GetInputsSize();
  for (size_t i = 0; i < input_count; ++i) {
    ge::GeTensorDesc input_desc = op_desc_ptr->GetInputDesc(i);
    const bool is_ref = ref_input_names.count(op_desc_ptr->GetInputNameByIndex(i)) > 0;
    inputs_type.push_back(
        static_cast<uint32_t>(ConvertGeDataType2TfDataType(input_desc.GetDataType(), is_ref)));
  }

  const size_t output_count = op_desc_ptr->GetOutputsSize();
  for (size_t i = 0; i < output_count; ++i) {
    ge::GeTensorDesc output_desc = op_desc_ptr->GetOutputDesc(i);
    outputs_type.push_back(
        static_cast<uint32_t>(ConvertGeDataType2TfDataType(output_desc.GetDataType(), false)));
  }
}
// NOTE(review): presumably registers TfKernelBuilder with the kernel builder
// factory under the key "TFBuilder" - confirm against the macro definition.
FACTORY_KERNEL_BUILDER_CLASS_KEY(TfKernelBuilder, "TFBuilder")
} // namespace aicpu